{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# <center>基于BERT的高细粒度命名实体识别模型</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part I. 模型构建与数据集加载"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 导入相关库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import torch.utils.data as torchData\n",
    "\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from loguru import logger\n",
    "\n",
    "import json\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载BERT模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the package under its own alias so that the name `bert` is bound only once,\n",
    "# to the loaded encoder that Bert_NER_Model reads as a notebook-level global.\n",
    "import pytorch_pretrained_bert as bert_pkg\n",
    "\n",
    "# Directory passed to from_pretrained for both the tokenizer and the encoder.\n",
    "bert_model_dir = \"../bert_model/bert-chinese/\"\n",
    "\n",
    "tokenizer = bert_pkg.BertTokenizer.from_pretrained(bert_model_dir)\n",
    "\n",
    "bert = bert_pkg.BertModel.from_pretrained(bert_model_dir)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-20 14:14:32.124 | INFO     | __main__:<module>:5 - 处理 : \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "个数: 10748\n",
      "序列样例:浙商银行企业信贷部叶老桂博士则从另一个角度对五道门槛进行了解读。叶老桂认为，对目前国内商业银行而言，\n",
      "标签样例:['B-company', 'I-company', 'I-company', 'I-company', 'O', 'O', 'O', 'O', 'O', 'B-name', 'I-name', 'I-name', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n",
      "span样例:[[[9, 11], 'name'], [[0, 3], 'company']]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-20 14:14:32.574 | INFO     | __main__:<module>:71 - 写入文件成功:\n",
      "2020-02-20 14:14:32.575 | INFO     | __main__:<module>:5 - 处理 : \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "个数: 1343\n",
      "序列样例:彭小军认为，国内银行现在走的是台湾的发卡模式，先通过跑马圈地再在圈的地里面选择客户，\n",
      "标签样例:['B-name', 'I-name', 'I-name', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-address', 'I-address', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n",
      "span样例:[[[15, 16], 'address'], [[0, 2], 'name']]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-20 14:14:32.665 | INFO     | __main__:<module>:71 - 写入文件成功:\n"
     ]
    }
   ],
   "source": [
    "root_dir = \"D:\\\\nlpCorpus\\\\CLUE\\\\dataset\\\\NER\\\\\"\n",
    "\n",
    "for subset in [\"train\", \"dev\"]:\n",
    "\n",
    "    # loguru fills `{}` placeholders str.format-style; without one the extra argument is never shown.\n",
    "    logger.info(\"处理 : {}\", subset)\n",
    "\n",
    "    seqs = []    # raw text of each example\n",
    "    labels = []  # per-character BIO tag strings of each example\n",
    "    spans = []   # [[start, end], entity_type] entries of each example\n",
    "\n",
    "    with open(root_dir + subset + \".json\", \"r\", encoding=\"utf-8\") as f:\n",
    "\n",
    "        # Every non-empty line is parsed as one JSON record with 'text' and 'label' keys.\n",
    "        for line in f:\n",
    "\n",
    "            if not line.strip():\n",
    "                continue\n",
    "\n",
    "            record = json.loads(line)\n",
    "            text = record[\"text\"]\n",
    "            text_labelseq = [\"O\"] * len(text)\n",
    "            seq_span = []\n",
    "\n",
    "            for NR_label, mentions in record[\"label\"].items():\n",
    "\n",
    "                # `mentions` holds one span list per value; visit all of them,\n",
    "                # not just the first, so that no annotated entity is dropped.\n",
    "                for mention_spans in mentions.values():\n",
    "\n",
    "                    for span in mention_spans:\n",
    "\n",
    "                        seq_span.append([span, NR_label])\n",
    "\n",
    "                        # span[1] is used as an inclusive end index\n",
    "                        for j in range(span[0], span[1] + 1):\n",
    "                            prefix = \"B-\" if j == span[0] else \"I-\"\n",
    "                            text_labelseq[j] = prefix + NR_label\n",
    "\n",
    "            seqs.append(text)\n",
    "            labels.append(text_labelseq)\n",
    "            spans.append(seq_span)\n",
    "\n",
    "    print(\"个数:\", len(seqs))\n",
    "    print(\"序列样例:\" + str(seqs[0]))\n",
    "    print(\"标签样例:\" + str(labels[0]))\n",
    "    print(\"span样例:\" + str(spans[0]))\n",
    "\n",
    "    # One JSON file per view of the subset, written to the working directory.\n",
    "    for suffix, payload in ((\"_seqs.json\", seqs), (\"_labels.json\", labels), (\"_spans.json\", spans)):\n",
    "        with open(subset + suffix, \"w\", encoding=\"utf-8\") as f:\n",
    "            f.write(json.dumps(payload))\n",
    "\n",
    "    logger.info(\"写入文件成功:\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构建命名实体识别模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "class Bert_NER_Model(torch.nn.Module):\n",
    "    '''\n",
    "    Token classifier: the pretrained BERT encoder followed by a two-layer tagging head.\n",
    "\n",
    "    Args:\n",
    "        n_hidden: width of the hidden layer of the tagging head.\n",
    "        labelnum: number of output tag classes.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, n_hidden,labelnum):\n",
    "\n",
    "        super().__init__()\n",
    "\n",
    "        # `bert` is the encoder loaded at notebook level, not a constructor argument.\n",
    "        self.bert = bert\n",
    "\n",
    "        # 768 is hard-coded as the encoder output width -- TODO confirm against the loaded checkpoint.\n",
    "        self.output = torch.nn.Sequential(\n",
    "            torch.nn.Linear(768, n_hidden),\n",
    "            torch.nn.Tanh(),\n",
    "            torch.nn.Dropout(0.2),\n",
    "            torch.nn.Linear(n_hidden, labelnum),\n",
    "        )\n",
    "\n",
    "        # Kept for callers that want normalized probabilities; forward() itself returns logits.\n",
    "        self.prob = torch.nn.Softmax(dim=-1)\n",
    "\n",
    "    def forward(self, b_seqs):\n",
    "        '''\n",
    "        Return unnormalized tag scores for every token.\n",
    "\n",
    "        Args:\n",
    "            b_seqs: token-id tensor handed straight to the encoder\n",
    "                (presumably (batch, seq_len) -- confirm against the caller).\n",
    "\n",
    "        Returns:\n",
    "            Raw logits with one score per tag in the last dimension. They are\n",
    "            deliberately not passed through softmax: torch.nn.CrossEntropyLoss\n",
    "            applies log-softmax itself, so feeding it probabilities applies\n",
    "            softmax twice and flattens the gradients. argmax over the last\n",
    "            dimension is the same for logits and probabilities.\n",
    "        '''\n",
    "\n",
    "        # NOTE(review): no attention mask is passed, so padding positions are attended to -- confirm this is intended.\n",
    "        encoded = self.bert(b_seqs, output_all_encoded_layers=False)[0]\n",
    "\n",
    "        return self.output(encoded)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构建数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'O': 0, 'B-address': 1, 'I-address': 2, 'B-book': 3, 'I-book': 4, 'B-company': 5, 'I-company': 6, 'B-game': 7, 'I-game': 8, 'B-government': 9, 'I-government': 10, 'B-movie': 11, 'I-movie': 12, 'B-name': 13, 'I-name': 14, 'B-organization': 15, 'I-organization': 16, 'B-position': 17, 'I-position': 18, 'B-scene': 19, 'I-scene': 20}\n"
     ]
    }
   ],
   "source": [
    "labels = {\n",
    "    \"address\":1,\n",
    "    \"book\":2,\n",
    "    \"company\":3,\n",
    "    \"game\":4,\n",
    "    \"government\":5,\n",
    "    \"movie\":6,\n",
    "    \"name\":7,\n",
    "    \"organization\":8,\n",
    "    \"position\":9,\n",
    "    \"scene\":10\n",
    "}\n",
    "ne_labels = {\"O\":0}\n",
    "count = 1\n",
    "for label in labels:\n",
    "    for head in ['B-',\"I-\"]:\n",
    "        ne_labels[head + label] = count \n",
    "        count += 1\n",
    "print(ne_labels)\n",
    "\n",
    "class NER_DataSet(torchData.Dataset):\n",
    "    '''\n",
    "    Character-level NER dataset read from the <subset>_seqs/_labels/_spans.json files.\n",
    "\n",
    "    Args:\n",
    "        subset: 'train' or 'dev' (case-insensitive).\n",
    "        num: maximum number of examples to keep; a negative value keeps all of them.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `subset` is unknown or no example was loaded.\n",
    "    '''\n",
    "\n",
    "    def __init__(self,subset = \"train\",num = - 1):\n",
    "\n",
    "        subset = subset.lower()\n",
    "\n",
    "        if subset not in (\"train\", \"dev\"):\n",
    "            raise ValueError(\"subset must be 'train' or 'dev', got \" + repr(subset))\n",
    "\n",
    "        # `[:num]` with the default -1 would silently drop the last example,\n",
    "        # so negative values are mapped to 'no limit'.\n",
    "        limit = num if num >= 0 else None\n",
    "\n",
    "        # Slice before converting so only the kept examples are processed.\n",
    "        with open(subset + \"_seqs.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "            self.seqs = [list(seq) for seq in json.load(f)[:limit]]\n",
    "\n",
    "        with open(subset + \"_labels.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "            self.labels = [[ne_labels[tag] for tag in seq_label] for seq_label in json.load(f)[:limit]]\n",
    "\n",
    "        with open(subset + \"_spans.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "            self.spans = json.load(f)[:limit]\n",
    "\n",
    "        if not self.seqs:\n",
    "            raise ValueError(\"no examples loaded for subset \" + repr(subset))\n",
    "\n",
    "        len_seqs = [len(x) for x in self.seqs]\n",
    "\n",
    "        self.MAX_SEQ_LEN = max(len_seqs)\n",
    "\n",
    "        # Token statistics, computed before padding is added.\n",
    "        self.labelNum = sum(len_seqs)\n",
    "        self.IB_Num = sum(sum(tag > 0 for tag in seq_label) for seq_label in self.labels)\n",
    "        self.O_Num = self.labelNum - self.IB_Num\n",
    "\n",
    "        # Class-balance ratios; fall back to 1.0 instead of dividing by zero when a class is absent.\n",
    "        self.IB_weight = round(self.O_Num / self.IB_Num, 2) if self.IB_Num else 1.0\n",
    "        self.O_weight = round(self.IB_Num / self.O_Num, 2) if self.O_Num else 1.0\n",
    "\n",
    "        # Prepend [CLS] and right-pad with [PAD] so every row holds MAX_SEQ_LEN + 1 tokens.\n",
    "        self.seqs = [[\"[CLS]\"] + seq + [\"[PAD]\"] * (self.MAX_SEQ_LEN - len(seq)) for seq in self.seqs]\n",
    "\n",
    "        # Label -1 marks the [CLS] and [PAD] positions.\n",
    "        self.labels = [[-1] + seq_label + [-1] * (self.MAX_SEQ_LEN - len(seq_label)) for seq_label in self.labels]\n",
    "\n",
    "        self.NE_num = sum(len(span) for span in self.spans)\n",
    "\n",
    "        logger.info(\"加载数据集\"+subset+\" token长度: \"+str(self.labelNum)+ \" 序列个数: \"+str(len(self.seqs))+\" 标签个数:\"+str(len(self.labels)) + \" 实体个数:\" + str(self.NE_num))\n",
    "\n",
    "        logger.debug(\"O-label / IB-label: \" + str(self.IB_weight) + \" O-NUM : \" + str(self.O_Num) + \" IB-NUM: \" + str(self.IB_Num))\n",
    "\n",
    "    def __len__(self):\n",
    "        '''Number of examples.'''\n",
    "\n",
    "        return len(self.seqs)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        '''Return (token ids, label tensor on the GPU, gold spans) for one example.'''\n",
    "\n",
    "        x = seq2ids(self.seqs[index])\n",
    "\n",
    "        # The label row is 1-D, so no transpose is needed before moving it to the GPU.\n",
    "        y = torch.tensor(self.labels[index]).cuda()\n",
    "\n",
    "        span = self.spans[index]\n",
    "\n",
    "        return (x, y, span)\n",
    "def seq2ids(seq):\n",
    "    '''\n",
    "    Convert a sequence of tokens to vocabulary ids with the notebook-level tokenizer.\n",
    "\n",
    "    The special tokens [CLS] and [PAD] are looked up unchanged. Every other\n",
    "    token is normalized first: ASCII uppercase letters are lowercased and a\n",
    "    few typographic punctuation marks are replaced. A token the tokenizer\n",
    "    rejects with KeyError is looked up as [UNK] instead.\n",
    "\n",
    "    Args:\n",
    "        seq: iterable of tokens (single characters or the special tokens).\n",
    "\n",
    "    Returns:\n",
    "        list of the ids returned by the tokenizer, in input order.\n",
    "    '''\n",
    "\n",
    "    # Code point -> replacement, built once per call instead of once per character.\n",
    "    char_replace = {\n",
    "        8220: \"\\\"\",\n",
    "        8221: \"\\\"\",\n",
    "        8212: \"-\",\n",
    "        8230: \"...\",\n",
    "        8216: \"＇\",\n",
    "        8217: \"＇\",\n",
    "    }\n",
    "\n",
    "    ids = []\n",
    "\n",
    "    for char in seq:\n",
    "\n",
    "        if char != \"[CLS]\" and char != \"[PAD]\":\n",
    "\n",
    "            # 65..90 are the ASCII uppercase letters\n",
    "            if 65 <= ord(char) <= 90:\n",
    "                char = char.lower()\n",
    "\n",
    "            char = char_replace.get(ord(char), char)\n",
    "\n",
    "        try:\n",
    "            ids += tokenizer.convert_tokens_to_ids([char])\n",
    "        except KeyError:\n",
    "            # Only an unknown-token lookup failure is handled; other errors propagate.\n",
    "            ids += tokenizer.convert_tokens_to_ids([\"[UNK]\"])\n",
    "\n",
    "    return ids"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 实例化模型与数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-20 14:14:36.408 | INFO     | __main__:__init__:63 - 加载数据集train token长度: 7268 序列个数: 200 标签个数:200 实体个数:342\n",
      "2020-02-20 14:14:36.410 | DEBUG    | __main__:__init__:65 - O-label / IB-label: 3.89 O-NUM : 5782 IB-NUM: 1486\n",
      "2020-02-20 14:14:36.441 | INFO     | __main__:__init__:63 - 加载数据集dev token长度: 7550 序列个数: 200 标签个数:200 实体个数:347\n",
      "2020-02-20 14:14:36.442 | DEBUG    | __main__:__init__:65 - O-label / IB-label: 3.89 O-NUM : 6006 IB-NUM: 1544\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "权重 tensor([1.0000, 3.8900, 3.8900, 3.8900, 3.8900, 3.8900, 3.8900, 3.8900, 3.8900,\n",
      "        3.8900, 3.8900, 3.8900, 3.8900, 3.8900, 3.8900, 3.8900, 3.8900, 3.8900,\n",
      "        3.8900, 3.8900, 3.8900])\n"
     ]
    }
   ],
   "source": [
    "# One output class per entry of the tag table instead of a hard-coded 21.\n",
    "NUM_LABELS = len(ne_labels)\n",
    "\n",
    "model = Bert_NER_Model(300, NUM_LABELS)\n",
    "\n",
    "model = torch.nn.DataParallel(model).cuda()\n",
    "\n",
    "# lr=1e-2 is orders of magnitude above the 2e-5..5e-5 range normally used to\n",
    "# fine-tune BERT with Adam. The run recorded with 1e-2 stopped predicting any\n",
    "# entity after the first epoch (parsed_num 0, label accuracy equal to the O ratio).\n",
    "# NOTE(review): the learning rate is the likely cause of that collapse -- confirm by re-running.\n",
    "optimizer = torch.optim.Adam([{'params': model.parameters()}, ], lr=3e-5)\n",
    "\n",
    "trainingSet = NER_DataSet(subset = \"train\",num = 200)\n",
    "\n",
    "devSet = NER_DataSet(subset = \"dev\",num = 200)\n",
    "\n",
    "# Every entity tag gets the O/entity token ratio as class weight; 'O' (id 0) keeps weight 1.\n",
    "weight = torch.FloatTensor([trainingSet.IB_weight] * NUM_LABELS)\n",
    "\n",
    "weight[0] = 1\n",
    "\n",
    "print(\"权重\",weight)\n",
    "\n",
    "# ignore_index=-1 skips the positions the dataset labels with -1.\n",
    "loss = torch.nn.CrossEntropyLoss(weight = weight.cuda(),ignore_index = -1,reduction = \"sum\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part II. 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Tag names in the key order of ne_labels; parse_predict indexes this list by tag id,\n",
    "# which relies on ne_labels having been filled in ascending id order.\n",
    "ne_labels_list = list(ne_labels.keys())\n",
    "def parse_predict(padded_predict : torch.tensor,label : torch.tensor):\n",
    "    '''\n",
    "    根据预测的标签序列解析出预测出的实体\n",
    "    '''\n",
    "\n",
    "    parsed = []\n",
    "\n",
    "    predict = [ ]\n",
    "    \n",
    "    for i in range(len(padded_predict)):\n",
    "        \n",
    "        if label[i].item() != -1 :\n",
    "            \n",
    "            predict.append(padded_predict[i].item())\n",
    "           \n",
    "    for i in range(len(predict)):\n",
    "        \n",
    "        predict_label = predict[i]\n",
    "        \n",
    "        if predict_label % 2 == 1:\n",
    "            \n",
    "            entity_span_start = i \n",
    "            \n",
    "            entity_span_end = i\n",
    "            \n",
    "            entity_type = ne_labels_list[predict_label][2:]\n",
    "            \n",
    "            for j in range(i + 1, len(predict)):\n",
    "                \n",
    "                next_predict_label = predict[j]\n",
    "                \n",
    "                if next_predict_label == predict_label + 1 :\n",
    "                \n",
    "                    entity_span_end  = j\n",
    "                    \n",
    "                if next_predict_label == 0 or next_predict_label == predict_label :\n",
    "                    \n",
    "                    break\n",
    "            parsed.append([[entity_span_start,entity_span_end],entity_type])\n",
    "\n",
    "    return parsed\n",
    "\n",
    "def measure(pred,label,span):\n",
    "    '''\n",
    "    Compare the entities decoded from one prediction with the gold spans.\n",
    "\n",
    "    Args:\n",
    "        pred: predicted tag ids for one sequence, handed to parse_predict.\n",
    "        label: gold tag ids for the same sequence, handed to parse_predict.\n",
    "        span: collection of gold entities that decoded entities are matched against.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'parsed_num' (decoded entities), 'accu' (decoded entities\n",
    "        found in `span`), 'exists' (len(span)) and 'parsed' (the decoded list).\n",
    "    '''\n",
    "\n",
    "    decoded = parse_predict(pred, label)\n",
    "\n",
    "    hits = sum(1 for candidate in decoded if candidate in span)\n",
    "\n",
    "    return {\n",
    "        \"parsed_num\": len(decoded),\n",
    "        \"accu\": hits,\n",
    "        \"exists\": len(span),\n",
    "        \"parsed\": decoded,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " epoch 0 : Loss :155.88104095458985 ,timeUsage:7.109553337097168,recall:0.0,parsed_num : 390,accu_num : 0\n",
      "label accu  0.6877 entity accu 0.0  O-label-ratio: 0.7955421023665382\n",
      "performance on dev set : loss: 157.5501 accu rate:0,recall : 0.0,parsed_num : 0,accu_num : 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\python36\\lib\\site-packages\\torch\\serialization.py:360: UserWarning: Couldn't retrieve source code for container of type Bert_NER_Model. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " epoch 1 : Loss :151.64886596679688 ,timeUsage:7.292750835418701,recall:0.0,parsed_num : 0,accu_num : 0\n",
      "label accu  0.7955 entity accu 0  O-label-ratio: 0.7955421023665382\n",
      "performance on dev set : loss: 157.5501 accu rate:0,recall : 0.0,parsed_num : 0,accu_num : 0\n",
      " epoch 2 : Loss :151.64886047363282 ,timeUsage:8.268913984298706,recall:0.0,parsed_num : 0,accu_num : 0\n",
      "label accu  0.7955 entity accu 0  O-label-ratio: 0.7955421023665382\n",
      "performance on dev set : loss: 157.5501 accu rate:0,recall : 0.0,parsed_num : 0,accu_num : 0\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-8-98196872fca8>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     67\u001b[0m         \u001b[0mb_loss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     68\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 69\u001b[1;33m         \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     70\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     71\u001b[0m         \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"\\r epoch : {},process : {} ,timeUsage:{}\"\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mprocess\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtime_usage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mend\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m\"\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mflush\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32md:\\python36\\lib\\site-packages\\torch\\optim\\adam.py\u001b[0m in \u001b[0;36mstep\u001b[1;34m(self, closure)\u001b[0m\n\u001b[0;32m    105\u001b[0m                 \u001b[0mstep_size\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgroup\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'lr'\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m/\u001b[0m \u001b[0mbias_correction1\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    106\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 107\u001b[1;33m                 \u001b[0mp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0maddcdiv_\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[0mstep_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mexp_avg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdenom\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    108\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    109\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Sets CUDA_LAUNCH_BLOCKING=1 in this process's environment.\n",
    "# NOTE(review): earlier cells already moved the model to the GPU, so CUDA is\n",
    "# presumably initialized by now and this may not take effect -- confirm.\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\"\n",
    "\n",
    "def collate_fn(data):\n",
    "    \n",
    "        b_seq,b_target,b_span =zip(*data)\n",
    "        \n",
    "        b_seq = torch.tensor(b_seq)\n",
    "\n",
    "        b_target = torch.stack(b_target,dim = 0)\n",
    "        \n",
    "        return b_seq,b_target,b_span\n",
    "    \n",
    "trainDataGenerator = torchData.DataLoader(dataset = trainingSet,batch_size = 15,shuffle = True,collate_fn = collate_fn)\n",
    "\n",
    "# The dev metrics are sums over the whole set, so its order does not need shuffling.\n",
    "devDataGenerator = torchData.DataLoader(dataset = devSet,batch_size = 15,shuffle = False,collate_fn = collate_fn)\n",
    "\n",
    "# One entry per epoch: {'train': [hits, predicted, loss], 'test': [hits, predicted, loss]}.\n",
    "history = []\n",
    "\n",
    "for epoch in range(50):\n",
    "\n",
    "    # ---- training pass ----\n",
    "    # Switch dropout back on; the dev pass at the end of each epoch turns it off.\n",
    "    model.train()\n",
    "\n",
    "    epoch_loss = []\n",
    "\n",
    "    start_time = time.time()\n",
    "\n",
    "    time_usage = 0.0\n",
    "\n",
    "    used_seq = 0\n",
    "\n",
    "    accu = 0          # predicted entities that match a gold entity\n",
    "\n",
    "    parsed_num = 0    # predicted entities\n",
    "\n",
    "    label_accu = 0    # token positions where prediction equals the gold tag\n",
    "\n",
    "    for b_seq,b_target,b_span in trainDataGenerator:\n",
    "\n",
    "        used_seq += len(b_seq)\n",
    "\n",
    "        process = str(round( used_seq / len(trainingSet) * 100 , 6)) + \"%\"\n",
    "\n",
    "        b_prob = model(b_seq)\n",
    "\n",
    "        b_predict = b_prob.argmax(dim = 2)\n",
    "\n",
    "        label_accu += torch.sum(torch.eq(b_target,b_predict)).item()\n",
    "\n",
    "        # Flatten (batch, seq, tags) to (batch * seq, tags) for the loss.\n",
    "        shape = list(b_prob.size())\n",
    "\n",
    "        b_prob = b_prob.view(shape[0] * shape[1] ,shape[2])\n",
    "\n",
    "        b_loss = loss(b_prob,b_target.view(shape[0] * shape[1]))\n",
    "\n",
    "        epoch_loss.append(b_loss.item())\n",
    "\n",
    "        for idx in range(len(b_predict)):\n",
    "\n",
    "            measured = measure(b_predict[idx],b_target[idx],b_span[idx])\n",
    "\n",
    "            parsed_num += measured[\"parsed_num\"]\n",
    "\n",
    "            accu += measured[\"accu\"]\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        b_loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "        # Measured after the update so the reported time includes this batch.\n",
    "        time_usage = time.time() - start_time\n",
    "\n",
    "        print(\"\\r epoch : {},process : {} ,timeUsage:{}\".format(epoch,process,time_usage),end = \"\",flush=True)\n",
    "\n",
    "    # recall = correctly recognized entities / gold entities\n",
    "    recall = round(accu / trainingSet.NE_num, 2) if trainingSet.NE_num > 0 else 0\n",
    "\n",
    "    # precision = correctly recognized entities / predicted entities\n",
    "    accuracy = round(accu / parsed_num if parsed_num > 0 else 0 ,2)\n",
    "\n",
    "    label_accu = round(label_accu/trainingSet.labelNum,4)\n",
    "\n",
    "    epoch_loss =  sum(epoch_loss)/len(trainingSet)\n",
    "\n",
    "    print(\"\\r epoch {} : Loss :{} ,timeUsage:{},recall:{},parsed_num : {},accu_num : {}\".format(epoch,epoch_loss, time_usage,recall,parsed_num,accu))\n",
    "\n",
    "    print(\"label accu \",label_accu,\"entity accu\",accuracy,\" O-label-ratio:\",trainingSet.O_Num/trainingSet.labelNum )\n",
    "\n",
    "    # ---- dev pass ----\n",
    "    # eval() disables dropout and no_grad() skips building autograd graphs,\n",
    "    # so the dev numbers are deterministic and cheaper to compute.\n",
    "    model.eval()\n",
    "\n",
    "    test_loss = []\n",
    "\n",
    "    test_parsed_num = 0\n",
    "\n",
    "    test_accu_num = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "\n",
    "        for b_seq,b_target,b_span in devDataGenerator:\n",
    "\n",
    "            b_prob = model(b_seq)\n",
    "\n",
    "            b_predict = b_prob.argmax(dim = 2)\n",
    "\n",
    "            shape = list(b_prob.size())\n",
    "\n",
    "            b_prob = b_prob.view(shape[0] * shape[1] ,shape[2])\n",
    "\n",
    "            b_loss = loss(b_prob,b_target.view(shape[0] * shape[1])).item()\n",
    "\n",
    "            test_loss.append(b_loss)\n",
    "\n",
    "            for idx in range(len(b_predict)):\n",
    "\n",
    "                measured = measure(b_predict[idx],b_target[idx],b_span[idx])\n",
    "\n",
    "                test_parsed_num += measured[\"parsed_num\"]\n",
    "\n",
    "                test_accu_num += measured[\"accu\"]\n",
    "\n",
    "    test_recall = round(test_accu_num / devSet.NE_num, 2) if devSet.NE_num > 0 else 0\n",
    "\n",
    "    test_accuracy = round(test_accu_num / test_parsed_num if test_parsed_num > 0 else 0 ,2)\n",
    "\n",
    "    test_loss = round(sum(test_loss) / len(devSet),4)\n",
    "\n",
    "    history.append({\"train\":[accu,parsed_num,epoch_loss],\"test\":[test_accu_num,test_parsed_num,test_loss]})\n",
    "\n",
    "    print(\"performance on dev set : loss:\",test_loss,\"accu rate:{},recall : {},parsed_num : {},accu_num : {}\".format(test_accuracy,test_recall,test_parsed_num,test_accu_num))\n",
    "\n",
    "    # Checkpoint the whole wrapped model every fifth epoch.\n",
    "    if epoch % 5 == 0:\n",
    "\n",
    "        torch.save(model, \"NER_MODEL_EPOCH_\" + str(epoch))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "a = torch.tensor(\n",
    "[\n",
    "    [\n",
    "        [0,0],[1,1],[2,2]\n",
    "    ],\n",
    "    [\n",
    "        [3,3],[4,4],[5,5],\n",
    "\n",
    "    ],  \n",
    "    [\n",
    "        [7,7],[8,8],[5,5],\n",
    "\n",
    "    ],  \n",
    "])\n",
    "b =torch.tensor(\n",
    "[\n",
    "    [\n",
    "        [0,0],[1,1],[2,2]\n",
    "    ],\n",
    "    [\n",
    "        [3,3],[4,4],[5,5],\n",
    "\n",
    "    ],  \n",
    "    [\n",
    "        [7,7],[8,8],[5,5],\n",
    "\n",
    "    ],  \n",
    "])\n",
    "print(b > 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
